#
# BN-Inception network components
# Details are in the Batch Normalization paper (Ioffe & Szegedy, 2015): https://arxiv.org/pdf/1502.03167v3.pdf
#

# ConvBNReLULayer: convolution -> batch normalization -> ReLU, chained with Sequential.
# Returns a function that is applied to a single input.
#   outChannels : first positional argument of ConvolutionalLayer (number of output feature maps)
#   kernel      : second positional argument of ConvolutionalLayer (kernel shape, e.g. (3:3))
#   stride      : forwarded as ConvolutionalLayer's stride
#   pad         : forwarded as ConvolutionalLayer's pad
#   bnScale     : forwarded as BatchNormalizationLayer's initialScale
#   bnTimeConst : forwarded as BatchNormalizationLayer's normalizationTimeConstant
# The convolution uses 'heNormal' initialization and is created with bias = false.
# NOTE(review): the bias is presumably omitted because the batch normalization that
# follows supplies its own bias term -- confirm.
ConvBNReLULayer {outChannels, kernel, stride, pad, bnScale, bnTimeConst} = Sequential(
    ConvolutionalLayer{outChannels, kernel, init = 'heNormal', stride = stride, pad = pad, bias = false} :
    BatchNormalizationLayer{spatialRank = 2, normalizationTimeConstant = bnTimeConst, initialScale = bnScale} :
    ReLU
)

# InceptionWithAvgPoolLayer: Inception block with four parallel branches that all read the
# same input x; the branch outputs are joined with Splice along axis 3.
# The pooling branch uses average pooling.
#   num1x1      : output channels of the 1x1 branch
#   num3x3r     : output channels of the 1x1 convolution that precedes the 3x3 convolution
#   num3x3      : output channels of the 3x3 convolution
#   num3x3dblr  : output channels of the 1x1 convolution that precedes the two 3x3 convolutions
#   num3x3dbl   : output channels of each of the two stacked 3x3 convolutions
#   numPool     : output channels of the 1x1 convolution that follows the pooling
#   bnScale     : passed to every ConvBNReLULayer (initial batch-normalization scale)
#   bnTimeConst : passed to every ConvBNReLULayer (batch-normalization time constant)
# Every convolution and the pooling use stride (1:1) with padding enabled.
# NOTE(review): axis 3 is presumably the channel axis of a width x height x channels
# input -- confirm against the caller.
InceptionWithAvgPoolLayer {num1x1, num3x3r, num3x3, num3x3dblr, num3x3dbl, numPool, bnScale, bnTimeConst} = {
    apply(x) = {
        # Branch 1: single 1x1 convolution
        branch1x1 = ConvBNReLULayer{num1x1, (1:1), (1:1), true, bnScale, bnTimeConst}(x)

        # Branch 2: 1x1 convolution (num3x3r channels) followed by a 3x3 convolution
        branch3x3 = Sequential( 
            ConvBNReLULayer{num3x3r, (1:1), (1:1), true, bnScale, bnTimeConst} :
            ConvBNReLULayer{num3x3,  (3:3), (1:1), true, bnScale, bnTimeConst}
        ) (x)

        # Branch 3: 1x1 convolution (num3x3dblr channels) followed by two 3x3 convolutions
        branch3x3dbl = Sequential(
            ConvBNReLULayer{num3x3dblr, (1:1), (1:1), true, bnScale, bnTimeConst} :
            ConvBNReLULayer{num3x3dbl,  (3:3), (1:1), true, bnScale, bnTimeConst} :
            ConvBNReLULayer{num3x3dbl,  (3:3), (1:1), true, bnScale, bnTimeConst}
        ) (x)

        # Branch 4: 3x3 average pooling (stride 1, padded) followed by a 1x1 convolution
        branch_pool = Sequential(
            AveragePoolingLayer{(3:3), stride = (1:1), pad = true} :
            ConvBNReLULayer{numPool, (1:1), (1:1), true, bnScale, bnTimeConst}
        ) (x)

        # Join the four branch outputs along axis 3
        out = Splice((branch1x1:branch3x3:branch3x3dbl:branch_pool), axis=3)
    }.out
}.apply

# InceptionWithMaxPoolLayer: Inception block with four parallel branches that all read the
# same input x; the branch outputs are joined with Splice along axis 3.
# The pooling branch uses max pooling.
#   num1x1      : output channels of the 1x1 branch
#   num3x3r     : output channels of the 1x1 convolution that precedes the 3x3 convolution
#   num3x3      : output channels of the 3x3 convolution
#   num3x3dblr  : output channels of the 1x1 convolution that precedes the two 3x3 convolutions
#   num3x3dbl   : output channels of each of the two stacked 3x3 convolutions
#   numPool     : output channels of the 1x1 convolution that follows the pooling
#   bnScale     : passed to every ConvBNReLULayer (initial batch-normalization scale)
#   bnTimeConst : passed to every ConvBNReLULayer (batch-normalization time constant)
# Every convolution and the pooling use stride (1:1) with padding enabled.
# NOTE(review): axis 3 is presumably the channel axis of a width x height x channels
# input -- confirm against the caller.
InceptionWithMaxPoolLayer {num1x1, num3x3r, num3x3, num3x3dblr, num3x3dbl, numPool, bnScale, bnTimeConst} = {
    apply(x) = {
        # Branch 1: single 1x1 convolution
        branch1x1 = ConvBNReLULayer{num1x1, (1:1), (1:1), true, bnScale, bnTimeConst}(x)

        # Branch 2: 1x1 convolution (num3x3r channels) followed by a 3x3 convolution
        branch3x3 = Sequential( 
            ConvBNReLULayer{num3x3r, (1:1), (1:1), true, bnScale, bnTimeConst} :
            ConvBNReLULayer{num3x3,  (3:3), (1:1), true, bnScale, bnTimeConst}
        ) (x)

        # Branch 3: 1x1 convolution (num3x3dblr channels) followed by two 3x3 convolutions
        branch3x3dbl = Sequential(
            ConvBNReLULayer{num3x3dblr, (1:1), (1:1), true, bnScale, bnTimeConst} :
            ConvBNReLULayer{num3x3dbl,  (3:3), (1:1), true, bnScale, bnTimeConst} :
            ConvBNReLULayer{num3x3dbl,  (3:3), (1:1), true, bnScale, bnTimeConst}
        ) (x)

        # Branch 4: 3x3 max pooling (stride 1, padded) followed by a 1x1 convolution
        branch_pool = Sequential(
            MaxPoolingLayer{(3:3), stride=(1:1), pad=true} :
            ConvBNReLULayer{numPool, (1:1), (1:1), true, bnScale, bnTimeConst}
        ) (x)

        # Join the four branch outputs along axis 3
        out = Splice((branch1x1:branch3x3:branch3x3dbl:branch_pool), axis=3)
    }.out
}.apply

# InceptionPassThroughLayer: Inception block with three parallel branches that all read the
# same input x; the branch outputs are joined with Splice along axis 3.
# Unlike the other Inception layers in this file there is no 1x1 branch, and the pooling
# output is spliced in directly without a 1x1 convolution after it. The last operation of
# every branch uses stride (2:2), so the spatial resolution is reduced.
#   num1x1      : not referenced in the body
#   num3x3r     : output channels of the 1x1 convolution that precedes the 3x3 convolution
#   num3x3      : output channels of the stride-2 3x3 convolution
#   num3x3dblr  : output channels of the 1x1 convolution that precedes the two 3x3 convolutions
#   num3x3dbl   : output channels of each of the two stacked 3x3 convolutions
#   numPool     : not referenced in the body
#   bnScale     : passed to every ConvBNReLULayer (initial batch-normalization scale)
#   bnTimeConst : passed to every ConvBNReLULayer (batch-normalization time constant)
# NOTE(review): num1x1 and numPool are presumably kept so the parameter list matches the
# other Inception layers -- confirm before removing them, since callers pass positionally.
# NOTE(review): axis 3 is presumably the channel axis of a width x height x channels
# input -- confirm against the caller.
InceptionPassThroughLayer {num1x1, num3x3r, num3x3, num3x3dblr, num3x3dbl, numPool, bnScale, bnTimeConst} = {
    apply(x) = {
        # Branch 1: 1x1 convolution (stride 1) followed by a 3x3 convolution with stride 2
        branch3x3 = Sequential( 
            ConvBNReLULayer{num3x3r, (1:1), (1:1), true, bnScale, bnTimeConst} :
            ConvBNReLULayer{num3x3,  (3:3), (2:2), true, bnScale, bnTimeConst}
        ) (x)

        # Branch 2: 1x1 convolution, 3x3 convolution (stride 1), then 3x3 convolution with stride 2
        branch3x3dbl = Sequential(
            ConvBNReLULayer{num3x3dblr, (1:1), (1:1), true, bnScale, bnTimeConst} :
            ConvBNReLULayer{num3x3dbl,  (3:3), (1:1), true, bnScale, bnTimeConst} :
            ConvBNReLULayer{num3x3dbl,  (3:3), (2:2), true, bnScale, bnTimeConst}
        ) (x)
        
        # Branch 3: 3x3 max pooling with stride 2, applied directly to the input
        branch_pool = MaxPoolingLayer{(3:3), stride=(2:2), pad=true}(x)

        # Join the three branch outputs along axis 3
        out = Splice((branch3x3:branch3x3dbl:branch_pool), axis=3)
    }.out
}.apply